package com.yeskery.nut.scan.bean;

import com.yeskery.nut.annotation.aop.*;
import com.yeskery.nut.aop.DefaultProxyObjectContext;
import com.yeskery.nut.aop.ProxyObjectContext;
import com.yeskery.nut.aop.aspect.AspectAdvice;
import com.yeskery.nut.aop.aspect.AspectFactory;
import com.yeskery.nut.aop.aspect.BeanJoinPoint;
import com.yeskery.nut.aop.aspect.JoinPoint;
import com.yeskery.nut.bean.ApplicationContext;
import com.yeskery.nut.bean.BaseApplicationContext;
import com.yeskery.nut.core.NutException;
import com.yeskery.nut.scan.AnnotationHandler;
import com.yeskery.nut.scan.BeanAnnotationScanMetadata;
import com.yeskery.nut.util.ReflectUtils;
import com.yeskery.nut.util.StringUtils;

import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

/**
 * Annotation handler that turns aspect declarations into cached method advice.
 *
 * <p>For every scanned class, the {@link Pointcut}-annotated methods of that class are parsed into
 * join points. The class's {@link Before}, {@link After}, {@link Around} and {@link Throwing} advice
 * methods are then resolved against those join points. An {@link AspectAdvice} is recorded in the
 * proxy object context's method-aspect cache for every bean method a join point matches.
 * @author Yeskery
 * 2023/8/7
 */
public class AnnotationAspectHandler implements AnnotationHandler {

    /** Proxy object context whose method-aspect cache is populated by this handler */
    private final ProxyObjectContext proxyObjectContext;

    /** Factory that parses pointcut expressions into join points */
    private final AspectFactory aspectFactory = new AspectFactory();

    /**
     * Create the aspect annotation handler.
     * @param applicationContext application context; it is cast unconditionally to
     *                           {@link BaseApplicationContext} to obtain the proxy object context
     */
    public AnnotationAspectHandler(ApplicationContext applicationContext) {
        this.proxyObjectContext = ((BaseApplicationContext) applicationContext).getProxyObjectContext();
    }

    /**
     * Scan every class for pointcut and advice declarations and register the resulting advice
     * for each matching bean method.
     * @param beanClassCollection classes to scan for aspect declarations
     * @param beanMetadataCollection metadata of all beans whose methods may be advised
     * @throws NutException if a pointcut expression is empty, an advice has no pointcut value,
     *                      or an advice references a pointcut method not declared on its class
     */
    @Override
    public void handle(Collection<Class<?>> beanClassCollection, Collection<BeanAnnotationScanMetadata> beanMetadataCollection) {
        Map<String, Collection<AspectAdvice>> cacheMap = ((DefaultProxyObjectContext) proxyObjectContext).getPoxyMethodAspectCacheMap();
        for (Class<?> clazz : beanClassCollection) {
            // Pointcuts are scoped to the class that declares them. Sharing one map across classes
            // would let a class resolve (or clobber) another class's same-named pointcut methods,
            // making the result depend on scan order.
            Map<String, JoinPoint> joinPointMap = collectJoinPoints(clazz);
            buildMethodProxyAspectCache(beanMetadataCollection, clazz, Before.class,
                    a -> getPointcut(((Before) a).pointcut(), ((Before) a).value()), joinPointMap, cacheMap);
            buildMethodProxyAspectCache(beanMetadataCollection, clazz, After.class,
                    a -> getPointcut(((After) a).pointcut(), ((After) a).value()), joinPointMap, cacheMap);
            buildMethodProxyAspectCache(beanMetadataCollection, clazz, Around.class,
                    a -> getPointcut(((Around) a).pointcut(), ((Around) a).value()), joinPointMap, cacheMap);
            buildMethodProxyAspectCache(beanMetadataCollection, clazz, Throwing.class,
                    a -> getPointcut(((Throwing) a).pointcut(), ((Throwing) a).value()), joinPointMap, cacheMap);
        }
    }

    /**
     * Collect the join points declared by the {@link Pointcut} methods of one class.
     * @param clazz class to scan
     * @return map from pointcut method name to its parsed join point
     * @throws NutException if a pointcut expression is empty
     */
    private Map<String, JoinPoint> collectJoinPoints(Class<?> clazz) {
        Map<String, JoinPoint> joinPointMap = new HashMap<>();
        for (Method method : ReflectUtils.getBeanAnnotationMethod(clazz, Pointcut.class)) {
            String express = method.getAnnotation(Pointcut.class).value();
            if (StringUtils.isEmpty(express)) {
                throw new NutException("Pointcut Express Must Not Be empty.");
            }
            // Keyed by bare method name; overloaded pointcut methods replace one another.
            joinPointMap.put(method.getName(), aspectFactory.getJoinPoint(express));
        }
        return joinPointMap;
    }

    /**
     * Build the proxy aspect cache entries for one advice annotation type of one class.
     * @param beanMetadataCollection bean metadata collection
     * @param clazz class declaring the advice methods
     * @param annotationClass advice annotation class
     * @param pointcutFunction function extracting the referenced pointcut name from the annotation
     * @param joinPointMap join points declared by {@code clazz}, keyed by pointcut method name
     * @param cacheMap cache map keyed by {@code Method.toString()} of the advised bean method
     * @throws NutException if an advice has no pointcut value or references an unknown pointcut
     */
    private void buildMethodProxyAspectCache(Collection<BeanAnnotationScanMetadata> beanMetadataCollection,
                                             Class<?> clazz, Class<? extends Annotation> annotationClass,
                                             Function<Annotation, String> pointcutFunction, Map<String, JoinPoint> joinPointMap,
                                             Map<String, Collection<AspectAdvice>> cacheMap) {
        for (Method method : ReflectUtils.getBeanAnnotationMethod(clazz, annotationClass)) {
            Annotation annotation = method.getAnnotation(annotationClass);
            String pointcut = pointcutFunction.apply(annotation);
            if (StringUtils.isEmpty(pointcut)) {
                throw new NutException("Class[" + clazz.getName() + "] Method[" + method + "] Need Pointcut Value.");
            }
            // Accept both "pc" and "pc()" as a reference to the pointcut method "pc".
            if (pointcut.endsWith("()")) {
                pointcut = pointcut.substring(0, pointcut.length() - 2);
            }
            JoinPoint joinPoint = joinPointMap.get(pointcut);
            if (joinPoint == null) {
                throw new NutException("Current Class[" + clazz.getName() + "] Can Not Found JoinPoint Method [" + pointcut + "]");
            }
            // Invariant for all matches of this advice method; computed once instead of per hit.
            String aspectBeanName = ReflectUtils.getDefaultBeanName(clazz);
            for (BeanAnnotationScanMetadata metadata : beanMetadataCollection) {
                // Aspect configuration classes are not themselves advised.
                if (metadata.getSource() == BeanAnnotationScanMetadata.Source.ASPECT) {
                    continue;
                }
                for (Method beanMethod : ReflectUtils.getBeanMethods(metadata.getType())) {
                    if (!isTargetMethod(joinPoint, metadata, beanMethod)) {
                        continue;
                    }
                    Collection<AspectAdvice> aspectAdvices = cacheMap.computeIfAbsent(beanMethod.toString(), k -> new ArrayList<>());
                    AspectAdvice aspectAdvice = new AspectAdvice();
                    aspectAdvice.setAnnotation(annotation);
                    aspectAdvice.setMethod(method);
                    aspectAdvice.setJoinPoint(joinPoint);
                    aspectAdvice.setBeanName(aspectBeanName);
                    aspectAdvice.setBeanType(clazz);
                    aspectAdvices.add(aspectAdvice);
                }
            }
        }
    }

    /**
     * Test whether a join point matches a bean method, passing bean metadata to bean-aware join points.
     * @param joinPoint join point to evaluate
     * @param metadata metadata of the bean owning {@code beanMethod}
     * @param beanMethod candidate bean method
     * @return whether the join point targets the bean method
     */
    private static boolean isTargetMethod(JoinPoint joinPoint, BeanAnnotationScanMetadata metadata, Method beanMethod) {
        if (joinPoint instanceof BeanJoinPoint) {
            return ((BeanJoinPoint) joinPoint).isTargetMethod(metadata, beanMethod);
        }
        return joinPoint.isTargetMethod(beanMethod);
    }

    /**
     * Get the referenced pointcut name, preferring the {@code pointcut} attribute.
     * @param pointcut value of the annotation's {@code pointcut} attribute
     * @param value value of the annotation's {@code value} attribute
     * @return {@code pointcut} if it is non-empty, otherwise {@code value}
     */
    private String getPointcut(String pointcut, String value) {
        return StringUtils.isEmpty(pointcut) ? value : pointcut;
    }
}
